"""Utilities for getting models."""

import numpy as np
from models.fcn import FCN
from models.nin import NIN
from .resnet import *


def get_model(config):
    """Instantiate the supervised model named by ``config["model"]["name"]``.

    Supported names, and the ``config["model"]`` keys each one reads:

    * ``"nin"`` -> ``NIN``: n_classes, n_channel, depth, width, batch_norm,
      dropout (all passed as keyword arguments).
    * ``"fcn"`` -> ``FCN``: input_size, n_classes, depth, width.
    * ``"resnet18"`` / ``"resnet34"`` -> ``ResNet`` built from ``BasicBlock``
      with layer list ``[2, 2, 2, 2]`` / ``[3, 4, 6, 3]``; reads width
      (passed positionally) and n_classes (passed as ``num_classes``).

    Args:
        config: Mapping with a ``"model"`` entry describing the model.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: If ``config`` has no ``"model"`` entry, or the spec lacks
            ``"name"`` or a key required by the selected architecture.
        ValueError: If the name is not one of the supported models.
    """
    spec = config["model"]
    arch = spec["name"]

    if arch == "nin":
        return NIN(
            n_classes=spec["n_classes"],
            n_channel=spec["n_channel"],
            depth=spec["depth"],
            width=spec["width"],
            batch_norm=spec["batch_norm"],
            dropout=spec["dropout"],
        )

    if arch == "fcn":
        return FCN(
            input_size=spec["input_size"],
            n_classes=spec["n_classes"],
            depth=spec["depth"],
            width=spec["width"],
        )

    # The ResNet variants differ only in the layer list handed to ResNet.
    # Plain == comparisons (rather than dict membership) keep unhashable
    # names on the ValueError path.
    if arch == "resnet18":
        layers = [2, 2, 2, 2]
    elif arch == "resnet34":
        layers = [3, 4, 6, 3]
    else:
        raise ValueError("Unrecognized model: {}".format(arch))

    return ResNet(
        BasicBlock,
        layers,
        spec["width"],
        num_classes=spec["n_classes"],
    )


def get_model_ssl(config):
    """Instantiate the self-supervised model named by ``config["model"]["name"]``.

    Only ``"resnet18"`` and ``"resnet34"`` are supported. A ``ResNet``
    backbone is built from ``BasicBlock`` with layer list ``[2, 2, 2, 2]``
    or ``[3, 4, 6, 3]`` respectively. It reads ``width`` (passed positionally)
    and ``n_classes`` (passed as ``num_classes``) from ``config["model"]``.
    The backbone is then wrapped in ``ResNetSimCLR``, with ``n_classes``
    passed again as its second argument.

    Args:
        config: Mapping with a ``"model"`` entry describing the model.

    Returns:
        The ``ResNetSimCLR`` instance wrapping the backbone.

    Raises:
        KeyError: If ``config`` has no ``"model"`` entry, or the spec lacks
            ``"name"``, ``"width"`` or ``"n_classes"``.
        ValueError: If the name is not a supported ResNet variant.
    """
    spec = config["model"]
    arch = spec["name"]

    # Both variants share construction; only the layer list differs.
    if arch == "resnet18":
        layers = [2, 2, 2, 2]
    elif arch == "resnet34":
        layers = [3, 4, 6, 3]
    else:
        raise ValueError("Unrecognized model: {}".format(arch))

    backbone = ResNet(
        BasicBlock,
        layers,
        spec["width"],
        num_classes=spec["n_classes"],
    )
    # NOTE(review): second argument is presumably the SimCLR projection /
    # output dimension; confirm against ResNetSimCLR's signature.
    return ResNetSimCLR(backbone, spec["n_classes"])
